"""

Authors: Pratik Bhatu.

Copyright:
Copyright (c) 2021 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

"""
import tensorflow as tf
import numpy as np
import argparse

from tf_graph_io import *
from tf_graph_trans import *

import os.path
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "TFCompiler"))
import DumpTFMtData

from os import path


def check_operation_exists(graph, tensor_name):
    """Report whether ``graph`` has an operation named exactly ``tensor_name``.

    Args:
        graph: Object exposing ``get_operations()``; each returned operation
            must have a ``name`` attribute.
        tensor_name: Operation name to look for (compared with ``==``).

    Returns:
        True if some operation's name equals ``tensor_name``, else False.
    """
    return any(op.name == tensor_name for op in graph.get_operations())


def _resolve_feed_shape(tensor, unknown_dim_msg, fallback_shape=None):
    """Pick the concrete shape of the dummy array fed for ``tensor``.

    Args:
        tensor: Graph tensor whose static ``shape`` is inspected.
        unknown_dim_msg: Message handed to ``sys.exit`` when the static shape
            has an unknown (``None``) dimension and no fallback is available.
        fallback_shape: Optional list of ints used in place of a static shape
            that contains unknown dimensions.

    Returns:
        ``[]`` when the tensor's dims are unknown (fed as a scalar), the
        static shape when every dimension is known, otherwise
        ``fallback_shape``.

    Raises:
        SystemExit: If a dimension is unknown and ``fallback_shape`` is empty
            or missing.
    """
    if tensor.shape.dims is None:
        return []
    static_shape = tensor.shape.as_list()
    if None not in static_shape:
        return static_shape
    if not fallback_shape:
        sys.exit(unknown_dim_msg)
    return fallback_shape


def compile(
    model_fname, input_t_name, output_t_name, scaling_factor, save_weights, input_shape
):
    """Dump metadata (and optionally fixed-point weights) for a processed model.

    Loads the graph from ``model_fname``, builds a feed dict of all-zero
    arrays for its input(s), and passes the output tensor to
    ``DumpTFMtData.save_graph_metadata``. When ``save_weights`` is true, the
    ``Variable``/``VariableV2`` nodes of the returned graph def are also
    written out through ``DumpTFMtData.dumpTrainedWeightsInt``.

    Side effect: the process working directory is changed to the directory
    containing the model, and it is not restored afterwards.

    Args:
        model_fname: Path to a ``.pb`` file whose base name starts with
            ``mpc_processed_``.
        input_t_name: Name of the input operation, or ``""`` to feed every
            ``Placeholder`` operation in the graph.
        output_t_name: Name of the output operation; its first output is used.
        scaling_factor: Value forwarded to ``dumpTrainedWeightsInt`` and
            embedded in the weights file name.
        save_weights: Whether to dump the weights file.
        input_shape: List of ints used when the named input tensor has
            unknown dimensions; ignored when ``input_t_name`` is ``""``.

    Raises:
        SystemExit: On a bad model file name, a missing input/output
            operation, or an input with unknown dimensions and no usable
            ``input_shape``.
    """
    processed_prefix = "mpc_processed_"
    base_name = os.path.basename(model_fname)

    if not model_fname.endswith(".pb"):
        sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)")
    # The prefix is later stripped from the base name by length, so it has to
    # be a real prefix of the file name; a marker that only appears in a
    # directory component must not be accepted.
    if not base_name.startswith(processed_prefix):
        sys.exit(
            """Please process model using preprocess_frozen_tf_graph.py.
This will optimise it and generate a new .pb with mpc_processed prefix.
Use that with this script."""
        )
    model_name = base_name[:-3]  # drop the ".pb" extension

    print("Loading processed tf graph ", model_fname)
    graph = load_pb(model_fname)

    if not check_operation_exists(graph, output_t_name):
        sys.exit(output_t_name + " output does not exist in the graph")
    output_t = graph.get_operation_by_name(output_t_name).outputs[0]

    # Inputs are fed with zero-filled arrays of the resolved shape.
    if input_t_name != "":
        if not check_operation_exists(graph, input_t_name):
            sys.exit(input_t_name + " input does not exist in the graph")

        input_t = graph.get_operation_by_name(input_t_name).outputs[0]
        feed_shape = _resolve_feed_shape(
            input_t,
            "Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help.",
            input_shape,
        )
        feed_dict = {input_t: np.zeros(feed_shape)}
    else:
        # No input named: treat every placeholder node as a model input.
        # There is no per-placeholder shape override, so unknown dims abort.
        feed_dict = {}
        for op in graph.get_operations():
            if op.type != "Placeholder":
                continue
            input_t = op.outputs[0]
            feed_shape = _resolve_feed_shape(
                input_t,
                "Please supply input names and their shapes for the input tensor as it is parametric (? dim) for this model. See --help.",
            )
            feed_dict[input_t] = np.zeros(feed_shape)

    with graph.as_default():
        with tf.Session() as sess:
            # Run initializers generated by preprocessing
            if check_operation_exists(graph, "init_constvars"):
                sess.run(graph.get_operation_by_name("init_constvars"))
            else:
                sess.run(tf.global_variables_initializer())
            # Outputs are meant to land in the model's folder, hence the
            # chdir before dumping. NOTE(review): presumably DumpTFMtData
            # writes relative to the current directory -- confirm.
            model_dir = os.path.realpath(os.path.dirname(model_fname))
            os.chdir(model_dir)
            optimized_graph_def = DumpTFMtData.save_graph_metadata(
                output_t, sess, feed_dict
            )
            print("Model compilation done.")
            if save_weights:
                # Variable tensors are only needed for the weight dump, so
                # they are looked up here rather than unconditionally.
                default_graph = tf.get_default_graph()
                train_vars = [
                    default_graph.get_operation_by_name(node.name).outputs[0]
                    for node in optimized_graph_def.node
                    if node.op in ("VariableV2", "Variable")
                ]
                DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess)
                weights_fname = (
                    model_name[len(processed_prefix) :]
                    + "_input_weights_fixedpt_scale_"
                    + str(scaling_factor)
                    + ".inp"
                )
                print(
                    "\nDumping model weights in ",
                    weights_fname,
                    ". These are to be used as input for party which owns the model\n",
                )
                DumpTFMtData.dumpTrainedWeightsInt(
                    sess, train_vars, weights_fname, scaling_factor, "w"
                )


def boolean_string(s):
    """Convert the literal strings ``"True"`` / ``"False"`` to a bool.

    Used as an ``argparse`` ``type=`` converter. The match is exact and
    case-sensitive.

    Raises:
        ValueError: If ``s`` is neither ``"True"`` nor ``"False"``.
    """
    literal_to_bool = {"False": False, "True": True}
    if s in literal_to_bool:
        return literal_to_bool[s]
    raise ValueError("Not a valid boolean string")


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with attributes ``modelName``,
        ``inputTensorName``, ``outputTensorName``, ``sf``, ``saveWeights``
        and ``inputTensorShape``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--modelName",
        required=True,
        type=str,
        help="Name of processed tensorflow model (mpc_processed*.pb)",
    )
    parser.add_argument(
        "--inputTensorName",
        type=str,
        default="",
        help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)",
    )
    parser.add_argument(
        "--outputTensorName",
        required=True,
        type=str,
        # This help text previously described the *input* tensor.
        help="Name of the output tensor for the model. (Op name, dont add '/:0' suffix)",
    )
    parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)")
    parser.add_argument(
        "--saveWeights",
        type=boolean_string,
        default=False,
        help="Dump model weights in fixedpt {True/False}",
    )
    parser.add_argument(
        "--inputTensorShape",
        type=str,
        default="",
        help='Comma separated list of shape for input tensor. eg: "2,245,234,3"',
    )
    return parser.parse_args(argv)


def get_shape_list(shape_string):
    """Turn a comma separated shape string into a list of ints.

    Args:
        shape_string: Text such as ``"2,245,234,3"``; the empty string means
            "no shape given".

    Returns:
        A list with one ``int`` per comma separated field, or ``[]`` for the
        empty string.

    Raises:
        ValueError: If any field is not a valid integer literal.
    """
    if shape_string != "":
        return list(map(int, shape_string.split(",")))
    return []


if __name__ == "__main__":
    # Script entry point: read the CLI options, convert the textual input
    # shape into a list of ints, and run the compilation step.
    cli_options = parse_args()
    requested_shape = get_shape_list(cli_options.inputTensorShape)
    compile(
        model_fname=cli_options.modelName,
        input_t_name=cli_options.inputTensorName,
        output_t_name=cli_options.outputTensorName,
        scaling_factor=cli_options.sf,
        save_weights=cli_options.saveWeights,
        input_shape=requested_shape,
    )
